{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import chromadb\n",
    "from chromadb.config import Settings\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chroma connection settings; the tenant/database values are Chroma's defaults.\n",
    "CHROMA_PATH = \"../datasets/chroma_db\"  # resolved against the kernel's working directory\n",
    "DEFAULT_TENANT = \"default_tenant\"\n",
    "DEFAULT_DATABASE = \"default_database\"\n",
    "\n",
    "# Open the on-disk Chroma store.\n",
    "client = chromadb.PersistentClient(\n",
    "    path=CHROMA_PATH,\n",
    "    tenant=DEFAULT_TENANT,\n",
    "    database=DEFAULT_DATABASE,\n",
    ")\n",
    "\n",
    "# Fetch the existing collection. get_collection (not get_or_create_collection) is deliberate:\n",
    "# a missing collection should raise instead of silently querying an empty one.\n",
    "collection_name = \"chat2db\"\n",
    "collection = client.get_collection(name=collection_name)\n",
    "assert collection.count() > 0, f\"collection {collection_name!r} is empty\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding model + matching tokenizer; queries must be embedded with the same model\n",
    "# that was used to index the collection.\n",
    "# NOTE(review): this checkpoint is primarily English while the texts here are Chinese;\n",
    "# a multilingual model may retrieve better, but switching requires re-indexing.\n",
    "model_name = \"sentence-transformers/all-roberta-large-v1\"\n",
    "tokenizer, model = (\n",
    "    AutoTokenizer.from_pretrained(model_name),\n",
    "    AutoModel.from_pretrained(model_name),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding(text: str, normalize: bool = False) -> np.ndarray:\n",
    "    \"\"\"Embed `text` with the global `tokenizer`/`model` using attention-mask-aware mean pooling.\n",
    "\n",
    "    Args:\n",
    "        text: the string to embed. A list of strings is also accepted and is\n",
    "            embedded as one padded batch.\n",
    "        normalize: if True, L2-normalize each pooled vector. sentence-transformers'\n",
    "            all-* checkpoints normally end with this step; whether it is needed here\n",
    "            depends on how the queried collection was built (TODO confirm).\n",
    "            Defaults to False, which keeps the previous behavior.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D vector for a single string (or a length-1 list), otherwise a\n",
    "        (batch, hidden) array.\n",
    "    \"\"\"\n",
    "    # Keep the inputs on the model's device so this still works after model.to(\"cuda\").\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    token_embeddings = outputs.last_hidden_state  # (batch, seq, hidden)\n",
    "    # Average over real tokens only: padded positions (attention_mask == 0) must not\n",
    "    # dilute the mean when several texts of different lengths are batched together.\n",
    "    mask = inputs[\"attention_mask\"].unsqueeze(-1).to(token_embeddings.dtype)\n",
    "    embeddings = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)\n",
    "    if normalize:\n",
    "        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n",
    "    # .cpu() is a no-op on CPU tensors and required before .numpy() on GPU tensors.\n",
    "    return embeddings.squeeze().cpu().numpy()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query: embed the text, then fetch the nearest chunk(s) from the collection.\n",
    "# NOTE: the hard line breaks inside query_text split words (e.g. 胡/斌); kept as copied.\n",
    "N_RESULTS = 1\n",
    "query_text = \"\"\"公司董事长兼总经理赵马克、董事兼副总经理胡戎、董事胡\n",
    "斌、董事胡志斌、监事胡涛、副总经理兼财务总监吴伟钢、副总\n",
    "经理易国平、副总经理骆敏健、副总经理戴大盛、副总经理兼董\n",
    "事会秘书王晓东承诺：上述承诺的限售期届满后，\"\"\"\n",
    "query_vector = get_embedding(query_text)\n",
    "\n",
    "results = collection.query(\n",
    "    # Pass plain floats: some older chromadb releases only accept lists, not numpy arrays.\n",
    "    query_embeddings=[query_vector.tolist()],\n",
    "    n_results=N_RESULTS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ids': [['18b899b1-cd8b-45f4-9db3-8ece917a6eb4']],\n",
       " 'embeddings': None,\n",
       " 'documents': [['图6-4 \\n \\n与同类技术相比，海默科技多相计量技术最重要的创新点在于： \\n(1) 采用了气液两相流量测量和油气水三相含水率独立进行的技术路线，从 \\n而确保了气液流量测量和油气水三相含水率测量分别可以采用最优化的方法进\\n行，流量测量的量程和精度与含水率测量的量程和精度之间相互独立，彼此不受']],\n",
       " 'uris': None,\n",
       " 'data': None,\n",
       " 'metadatas': [[{'author': 'rd-03',\n",
       "    'creationDate': \"D:20100427010927+08'00'\",\n",
       "    'creator': 'PScript5.dll Version 5.2.2',\n",
       "    'file_path': 'D:\\\\AIClass\\\\chat2money\\\\datasets\\\\pdf\\\\d6867ff3ecf9882065e45612eaf7b33ccb9a95dc.PDF',\n",
       "    'format': 'PDF 1.4',\n",
       "    'keywords': '',\n",
       "    'modDate': \"D:20100427010927+08'00'\",\n",
       "    'page': 86,\n",
       "    'producer': 'Acrobat Distiller 7.0 (Windows)',\n",
       "    'source': 'D:\\\\AIClass\\\\chat2money\\\\datasets\\\\pdf\\\\d6867ff3ecf9882065e45612eaf7b33ccb9a95dc.PDF',\n",
       "    'start_index': 406,\n",
       "    'subject': '',\n",
       "    'title': '<4D6963726F736F667420576F7264202D20BAA3C4ACBFC6BCBCA3BACAD7B4CEB9ABBFAAB7A2D0D0B9C9C6B1B2A2D4DAB4B4D2B5B0E5C9CFCAD0D5D0B9C9D2E2CFF2CAE92E646F63>',\n",
       "    'total_pages': 315,\n",
       "    'trapped': ''}]],\n",
       " 'distances': [[415.4557800292969]],\n",
       " 'included': [<IncludeEnum.distances: 'distances'>,\n",
       "  <IncludeEnum.documents: 'documents'>,\n",
       "  <IncludeEnum.metadatas: 'metadatas'>]}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compact view of the hits; the raw `results` dict also carries every PDF metadata\n",
    "# field, including absolute local file paths, which is noisy to display and share.\n",
    "[\n",
    "    {\"distance\": distance, \"page\": (metadata or {}).get(\"page\"), \"document\": document}\n",
    "    for document, metadata, distance in zip(\n",
    "        results[\"documents\"][0], results[\"metadatas\"][0], results[\"distances\"][0]\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chat2money",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
